# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test to demonstrate custom training loop with ParameterServerStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

import portpicker
import tensorflow.compat.v2 as tf
from absl import logging

# Number of iterations of the outer (epoch) loop in the test.
NUM_EPOCHS = 10
# Training steps per epoch; they are scheduled in chunks of
# STEPS_PER_EXECUTION.
NUM_STEPS = 100
# Number of steps run inside a single scheduled `worker_train_fn` call.
STEPS_PER_EXECUTION = 10


class ParameterServerCustomTrainingLoopTest(tf.test.TestCase):
    """Demonstrates a custom training loop under ParameterServerStrategy.

    `setUp` builds an in-process cluster of gRPC servers together with the
    strategy and coordinator; the test then schedules training functions onto
    that cluster through the coordinator.
    """

    def create_in_process_cluster(self, num_workers, num_ps):
        """Creates local worker/ps servers and returns their cluster spec.

        Args:
          num_workers: Number of tasks to create in the "worker" job.
          num_ps: Number of tasks to create in the "ps" job.

        Returns:
          The `tf.train.ClusterSpec` built from the picked localhost ports.
        """
        # Worker ports are picked first, then the parameter server ports.
        worker_addresses = [
            f"localhost:{portpicker.pick_unused_port()}"
            for _ in range(num_workers)
        ]
        ps_addresses = [
            f"localhost:{portpicker.pick_unused_port()}"
            for _ in range(num_ps)
        ]

        # The "ps" job is only listed when at least one ps task is requested.
        cluster_dict = {"worker": worker_addresses}
        if num_ps > 0:
            cluster_dict["ps"] = ps_addresses
        cluster_spec = tf.train.ClusterSpec(cluster_dict)

        # Workers need some inter_ops threads to work properly.
        worker_config = tf.compat.v1.ConfigProto()
        min_threads = num_workers + 1
        if multiprocessing.cpu_count() < min_threads:
            worker_config.inter_op_parallelism_threads = min_threads

        # The server objects are created for their side effect only; no
        # handle to them is kept.
        for task_index in range(num_workers):
            tf.distribute.Server(
                cluster_spec,
                job_name="worker",
                task_index=task_index,
                config=worker_config,
                protocol="grpc",
            )
        for task_index in range(num_ps):
            tf.distribute.Server(
                cluster_spec,
                job_name="ps",
                task_index=task_index,
                protocol="grpc",
            )

        return cluster_spec

    def setUp(self):
        """Builds the cluster, strategy and coordinator used by the test."""
        super().setUp()

        resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
            self.create_in_process_cluster(num_workers=3, num_ps=2),
            rpc_layer="grpc",
        )
        self.strategy = tf.distribute.experimental.ParameterServerStrategy(
            resolver
        )
        coordinator_lib = tf.distribute.experimental.coordinator
        self.coordinator = coordinator_lib.ClusterCoordinator(self.strategy)

    def testCustomTrainingLoop(self):
        """Schedules NUM_EPOCHS epochs of training steps via the coordinator.

        Each scheduled `worker_train_fn` call runs STEPS_PER_EXECUTION steps,
        and NUM_STEPS / STEPS_PER_EXECUTION calls are scheduled per epoch.
        """
        strategy = self.strategy
        coordinator = self.coordinator

        def per_worker_dataset_fn():
            def dataset_fn(_):
                # Random features and labels: 6 examples of width 10, batched
                # in pairs and repeated indefinitely.
                features = tf.random.uniform((6, 10))
                labels = tf.random.uniform((6, 10))
                dataset = tf.data.Dataset.from_tensor_slices((features, labels))
                return dataset.batch(2).repeat()

            return strategy.distribute_datasets_from_function(dataset_fn)

        per_worker_dataset = coordinator.create_per_worker_dataset(
            per_worker_dataset_fn
        )

        # Model, optimizer and metric are all created under the strategy
        # scope.
        with strategy.scope():
            model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
            optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
            train_accuracy = tf.keras.metrics.CategoricalAccuracy(
                name="train_accuracy"
            )

        @tf.function
        def worker_train_fn(iterator):
            def replica_fn(inputs):
                """Runs one training step on a single batch."""
                features, targets = inputs
                with tf.GradientTape() as tape:
                    outputs = model(features, training=True)
                    loss_fn = tf.keras.losses.CategoricalCrossentropy(
                        reduction=tf.keras.losses.Reduction.NONE
                    )
                    loss = loss_fn(targets, outputs)
                variables = model.trainable_variables
                gradients = tape.gradient(loss, variables)

                optimizer.apply_gradients(
                    zip(gradients, model.trainable_variables)
                )
                train_accuracy.update_state(targets, outputs)

            # One scheduled call performs STEPS_PER_EXECUTION training steps.
            for _ in tf.range(STEPS_PER_EXECUTION):
                strategy.run(replica_fn, args=(next(iterator),))

        last_epoch = NUM_EPOCHS - 1
        for epoch in range(NUM_EPOCHS):
            # A fresh iterator over the per-worker dataset for every epoch.
            distributed_iterator = iter(per_worker_dataset)

            for step in range(0, NUM_STEPS, STEPS_PER_EXECUTION):
                coordinator.schedule(
                    worker_train_fn, args=(distributed_iterator,)
                )
                logging.info("Epoch %d, step %d scheduled.", epoch, step)

            logging.info("Now joining at epoch %d.", epoch)
            coordinator.join()
            accuracy = train_accuracy.result()
            total_iterations = optimizer.iterations.value()
            logging.info(
                "Finished joining at epoch %d. Training accuracy: %f. "
                "Total iterations: %d",
                epoch,
                accuracy,
                total_iterations,
            )

            # The metric is reset between epochs but left untouched after the
            # final one.
            if epoch != last_epoch:
                train_accuracy.reset_states()


if __name__ == "__main__":
    # The test runner is only invoked when TF2 behavior is enabled; otherwise
    # running this file as a script does nothing.
    tf2_enabled = tf.__internal__.tf2.enabled()
    if tf2_enabled:
        tf.__internal__.distribute.multi_process_runner.test_main()
